// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

//go:build !js
// +build !js

package webrtc

import (
	"encoding/binary"
	"fmt"
	"io"
	"math"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pion/interceptor"
	"github.com/pion/interceptor/pkg/stats"
	"github.com/pion/logging"
	"github.com/pion/rtcp"
	"github.com/pion/srtp/v3"
	"github.com/pion/webrtc/v4/internal/util"
)

// trackStreams maintains a mapping of RTP/RTCP streams to a specific track.
// An RTPReceiver may contain multiple trackStreams if we are dealing with Simulcast.
type trackStreams struct {
	// track is the TrackRemote fed by the streams below.
	track *TrackRemote

	// streamInfo describes the primary media stream and repairStreamInfo the
	// RTX repair stream; Stop unbinds both from the interceptor chain.
	streamInfo, repairStreamInfo *interceptor.StreamInfo

	// Primary RTP stream and the interceptor chain that reads from it.
	rtpReadStream  *srtp.ReadStreamSRTP
	rtpInterceptor interceptor.RTPReader

	// Primary RTCP stream and the interceptor chain that reads from it.
	rtcpReadStream  *srtp.ReadStreamSRTCP
	rtcpInterceptor interceptor.RTCPReader

	// RTX repair RTP stream. repairStreamChannel is filled by the goroutine
	// started in receiveForRtx and drained by readRTX.
	repairReadStream    *srtp.ReadStreamSRTP
	repairInterceptor   interceptor.RTPReader
	repairStreamChannel chan rtxPacketWithAttributes

	// RTCP stream that accompanies the RTX repair stream.
	repairRtcpReadStream  *srtp.ReadStreamSRTCP
	repairRtcpInterceptor interceptor.RTCPReader
}

// rtxPacketWithAttributes is an RTX packet that receiveForRtx has already
// rewritten into the original RTP packet format, together with the interceptor
// attributes read alongside it. pkt is a sub-slice of a buffer borrowed from
// pool; call release to hand the buffer back.
type rtxPacketWithAttributes struct {
	pkt        []byte
	attributes interceptor.Attributes
	pool       *sync.Pool
}

// release returns the packet's backing buffer (re-extended to its full
// capacity) to the pool it came from and clears pkt, so a second call is a
// no-op.
func (p *rtxPacketWithAttributes) release() {
	if p.pkt == nil {
		return
	}

	full := p.pkt[:cap(p.pkt)]
	p.pool.Put(full) // nolint:staticcheck
	p.pkt = nil
}

// RTPReceiver allows an application to inspect the receipt of a TrackRemote.
type RTPReceiver struct {
	kind      RTPCodecType
	transport *DTLSTransport

	// tracks holds one entry per encoding passed to configureReceive
	// (more than one when receiving Simulcast).
	tracks []trackStreams

	// closed and closedChan are both set by Stop (haveClosed reports closed).
	// received is closed by startReceive once the SSRC streams are set up.
	// mu guards tracks, tr and transport.
	closed               atomic.Bool
	closedChan, received chan any
	mu                   sync.RWMutex

	tr *RTPTransceiver

	// A reference to the associated api object
	api *API

	// rtxPool recycles receive-MTU sized buffers used for RTX packets.
	rtxPool sync.Pool

	log logging.LeveledLogger
}

// NewRTPReceiver constructs a new RTPReceiver for media of the given kind on
// transport. It returns errRTPReceiverDTLSTransportNil when transport is nil.
func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RTPReceiver, error) {
	if transport == nil {
		return nil, errRTPReceiverDTLSTransportNil
	}

	// The receive MTU is looked up each time the pool has to allocate.
	newRTXBuffer := func() any {
		return make([]byte, api.settingEngine.getReceiveMTU())
	}

	receiver := &RTPReceiver{
		kind:       kind,
		transport:  transport,
		api:        api,
		closedChan: make(chan any),
		received:   make(chan any),
		tracks:     []trackStreams{},
		rtxPool:    sync.Pool{New: newRTXBuffer},
		log:        api.settingEngine.LoggerFactory.NewLogger("RTPReceiver"),
	}

	return receiver, nil
}

// setRTPTransceiver records the transceiver that owns this receiver.
func (r *RTPReceiver) setRTPTransceiver(tr *RTPTransceiver) {
	r.mu.Lock()
	r.tr = tr
	r.mu.Unlock()
}

// Transport returns the currently-configured *DTLSTransport or nil
// if one has not yet been configured.
func (r *RTPReceiver) Transport() *DTLSTransport {
	r.mu.RLock()
	transport := r.transport
	r.mu.RUnlock()

	return transport
}

// getParameters builds the receive-only RTP parameters for this receiver's
// kind, replacing the codec list with the transceiver's codecs when a
// transceiver is attached. Callers in this file hold r.mu.
func (r *RTPReceiver) getParameters() RTPParameters {
	directions := []RTPTransceiverDirection{RTPTransceiverDirectionRecvonly}
	params := r.api.mediaEngine.getRTPParametersByKind(r.kind, directions)

	if tr := r.tr; tr != nil {
		params.Codecs = tr.getCodecs()
	}

	return params
}

// GetParameters describes the current configuration for the encoding and
// transmission of media on the receiver's track. It holds the receiver's read
// lock while the parameters are assembled.
func (r *RTPReceiver) GetParameters() RTPParameters {
	r.mu.RLock()
	defer r.mu.RUnlock()

	params := r.getParameters()

	return params
}

// Track returns the RtpTransceiver TrackRemote. It returns nil unless the
// receiver has exactly one track; use Tracks when receiving Simulcast.
func (r *RTPReceiver) Track() *TrackRemote {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if len(r.tracks) == 1 {
		return r.tracks[0].track
	}

	return nil
}

// Tracks returns the RtpTransceiver tracks.
// An RTPReceiver that supports Simulcast may have multiple tracks. The result
// is nil when the receiver has no tracks.
func (r *RTPReceiver) Tracks() []*TrackRemote {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if len(r.tracks) == 0 {
		return nil
	}

	tracks := make([]*TrackRemote, 0, len(r.tracks))
	for i := range r.tracks {
		tracks = append(tracks, r.tracks[i].track)
	}

	return tracks
}

// RTPTransceiver returns the RTPTransceiver this
// RTPReceiver belongs to, or nil if none.
//
// Like the other read-only getters (Transport, Track) it only needs the read
// lock.
func (r *RTPReceiver) RTPTransceiver() *RTPTransceiver {
	r.mu.RLock()
	defer r.mu.RUnlock()

	return r.tr
}

// configureReceive creates one TrackRemote (wrapped in a trackStreams entry)
// per encoding in parameters and appends it to r.tracks. No streams are opened
// here; that is startReceive's job.
func (r *RTPReceiver) configureReceive(parameters RTPReceiveParameters) {
	r.mu.Lock()
	defer r.mu.Unlock()

	for idx := range parameters.Encodings {
		encoding := &parameters.Encodings[idx]
		remote := newTrackRemote(r.kind, encoding.SSRC, encoding.RTX.SSRC, encoding.RID, r)
		r.tracks = append(r.tracks, trackStreams{track: remote})
	}
}

// startReceive opens the SRTP/SRTCP read streams, plus any RTX repair streams,
// for every SSRC-signalled encoding in parameters and then closes r.received.
// Encodings carrying a RID are skipped here; receiveForRid sets them up later.
//
// It returns errRTPReceiverReceiveAlreadyCalled on a repeated call, an error
// wrapping errRTPReceiverWithSSRCTrackStreamNotFound when no configured track
// matches an encoding's SSRC, or any error from opening streams.
func (r *RTPReceiver) startReceive(parameters RTPReceiveParameters) error { //nolint:cyclop
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.haveReceived() {
		return errRTPReceiverReceiveAlreadyCalled
	}

	global := r.getParameters()
	var codec RTPCodecCapability
	if len(global.Codecs) > 0 {
		codec = global.Codecs[0].RTPCodecCapability
	}

	for idx := range parameters.Encodings {
		encoding := &parameters.Encodings[idx]
		if encoding.RID != "" {
			// RID based tracks will be set up in receiveForRid
			continue
		}

		var streams *trackStreams
		for i := range r.tracks {
			if remote := r.tracks[i].track; remote != nil && remote.SSRC() == encoding.SSRC {
				streams = &r.tracks[i]

				break
			}
		}
		if streams == nil {
			return fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, encoding.SSRC)
		}

		streams.streamInfo = createStreamInfo("", encoding.SSRC, 0, 0, 0, 0, 0, codec, global.HeaderExtensions)
		media, err := r.transport.streamsForSSRC(encoding.SSRC, *streams.streamInfo)
		if err != nil {
			return err
		}
		streams.rtpReadStream = media.rtpReadStream
		streams.rtpInterceptor = media.rtpInterceptor
		streams.rtcpReadStream = media.rtcpReadStream
		streams.rtcpInterceptor = media.rtcpInterceptor

		rtxSSRC := encoding.RTX.SSRC
		if rtxSSRC == 0 {
			continue
		}

		rtxInfo := createStreamInfo("", rtxSSRC, 0, 0, 0, 0, 0, codec, global.HeaderExtensions)
		repair, err := r.transport.streamsForSSRC(rtxSSRC, *rtxInfo)
		if err != nil {
			return err
		}
		if err = r.receiveForRtx(
			rtxSSRC,
			"",
			rtxInfo,
			repair.rtpReadStream,
			repair.rtpInterceptor,
			repair.rtcpReadStream,
			repair.rtcpInterceptor,
		); err != nil {
			return err
		}
	}

	close(r.received)

	return nil
}

// Receive initializes the tracks described by parameters and starts all the
// transports; it is configureReceive followed by startReceive.
func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
	r.configureReceive(parameters)
	if err := r.startReceive(parameters); err != nil {
		return err
	}

	return nil
}

// Read reads incoming RTCP for this RTPReceiver.
//
// It blocks until Receive has completed, returning io.ErrClosedPipe if the
// receiver is stopped first. With more than one track (Simulcast) an error is
// logged and the first track's RTCP is read; use ReadSimulcast instead. If the
// first track has no RTCP stream yet (or there are no tracks) an error
// wrapping errRTPReceiverWithSSRCTrackStreamNotFound is returned.
func (r *RTPReceiver) Read(b []byte) (n int, a interceptor.Attributes, err error) {
	select {
	case <-r.received:
	case <-r.closedChan:
		return 0, nil, io.ErrClosedPipe
	}

	// Snapshot under the read lock (receiveForRid assigns these fields under
	// r.mu), but do not hold it across the blocking Read so Stop can proceed.
	var (
		rtcpInterceptor interceptor.RTCPReader
		firstTrack      *TrackRemote
	)
	r.mu.RLock()
	trackCount := len(r.tracks)
	if trackCount > 0 {
		rtcpInterceptor = r.tracks[0].rtcpInterceptor
		firstTrack = r.tracks[0].track
	}
	r.mu.RUnlock()

	if trackCount > 1 {
		r.log.Errorf(useReadSimulcast)
	}

	if rtcpInterceptor == nil {
		var ssrc SSRC
		if firstTrack != nil {
			ssrc = firstTrack.SSRC()
		}

		return 0, nil, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, ssrc)
	}

	return rtcpInterceptor.Read(b, a)
}

// ReadSimulcast reads incoming RTCP for this RTPReceiver for given rid.
//
// It blocks until Receive has completed, returning io.ErrClosedPipe if the
// receiver is stopped first. The first track whose RID matches is used, the
// same rule SetReadDeadlineSimulcast applies; an error wrapping
// errRTPReceiverForRIDTrackStreamNotFound is returned when no such track has an
// RTCP interceptor.
func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, a interceptor.Attributes, err error) {
	select {
	case <-r.received:
	case <-r.closedChan:
		return 0, nil, io.ErrClosedPipe
	}

	// Read-only lookup: the read lock suffices, and it is released before the
	// blocking Read so Stop can acquire r.mu.
	var rtcpInterceptor interceptor.RTCPReader
	r.mu.RLock()
	for i := range r.tracks {
		if t := r.tracks[i].track; t != nil && t.rid == rid {
			rtcpInterceptor = r.tracks[i].rtcpInterceptor

			break
		}
	}
	r.mu.RUnlock()

	if rtcpInterceptor == nil {
		return 0, nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
	}

	return rtcpInterceptor.Read(b, a)
}

// ReadRTCP is a convenience method that wraps Read and unmarshal for you.
// It also runs any configured interceptors. On any error both the packets and
// the attributes are nil.
func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
	buf := make([]byte, r.api.settingEngine.getReceiveMTU())

	n, attrs, err := r.Read(buf)
	if err != nil {
		return nil, nil, err
	}

	packets, err := rtcp.Unmarshal(buf[:n])
	if err != nil {
		return nil, nil, err
	}

	return packets, attrs, nil
}

// ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you.
// Like ReadRTCP, it returns nil packets and nil attributes on any error.
func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, interceptor.Attributes, error) {
	b := make([]byte, r.api.settingEngine.getReceiveMTU())
	i, attributes, err := r.ReadSimulcast(b, rid)
	if err != nil {
		return nil, nil, err
	}

	pkts, err := rtcp.Unmarshal(b[:i])
	if err != nil {
		return nil, nil, err
	}

	return pkts, attributes, nil
}

// haveReceived reports, without blocking, whether startReceive has closed
// r.received.
func (r *RTPReceiver) haveReceived() bool {
	received := false
	select {
	case <-r.received:
		received = true
	default:
	}

	return received
}

// haveClosed reports whether Stop has run on this receiver.
func (r *RTPReceiver) haveClosed() bool {
	closed := r.closed.Load()

	return closed
}

// Stop irreversibly stops the RTPReceiver.
//
// If Receive has completed, every track's RTP/RTCP and repair read streams are
// closed and their stream infos are unbound from the interceptor chain. The
// returned error combines the close errors of all tracks. A repeated call is a
// no-op that returns nil.
func (r *RTPReceiver) Stop() error { //nolint:cyclop
	r.mu.Lock()
	defer r.mu.Unlock()

	select {
	case <-r.closedChan:
		return nil
	default:
	}

	// Accumulate across all tracks so no track's close error is dropped.
	var errs []error
	if r.haveReceived() {
		for i := range r.tracks {
			streams := &r.tracks[i]

			if streams.rtcpReadStream != nil {
				errs = append(errs, streams.rtcpReadStream.Close())
			}

			if streams.rtpReadStream != nil {
				errs = append(errs, streams.rtpReadStream.Close())
			}

			if streams.repairReadStream != nil {
				errs = append(errs, streams.repairReadStream.Close())
			}

			if streams.repairRtcpReadStream != nil {
				errs = append(errs, streams.repairRtcpReadStream.Close())
			}

			if streams.streamInfo != nil {
				r.api.interceptor.UnbindRemoteStream(streams.streamInfo)
			}

			if streams.repairStreamInfo != nil {
				r.api.interceptor.UnbindRemoteStream(streams.repairStreamInfo)
			}
		}
	}

	close(r.closedChan)
	r.closed.Store(true)

	return util.FlattenErrs(errs)
}

// collectStats adds an inbound-rtp entry to collector for every track of this
// receiver, plus audio playout entries for audio tracks. Nothing is collected
// when statsGetter is nil.
func (r *RTPReceiver) collectStats(collector *statsReportCollector, statsGetter stats.Getter) {
	if statsGetter == nil {
		return
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	var mid string
	if r.tr != nil {
		mid = r.tr.Mid()
	}

	now := statsTimestampNow()
	nowTime := now.Time()

	for i := range r.tracks {
		remoteTrack := r.tracks[i].track
		if remoteTrack == nil {
			continue
		}

		collector.Collecting()

		ssrc := remoteTrack.SSRC()
		inboundID := fmt.Sprintf("inbound-rtp-%d", uint32(ssrc))
		inboundStats := InboundRTPStreamStats{
			Mid:         mid,
			Timestamp:   now,
			Type:        StatsTypeInboundRTP,
			ID:          inboundID,
			SSRC:        ssrc,
			Kind:        r.kind.String(),
			TransportID: "iceTransport",
			CodecID:     remoteTrack.codec.statsID,
		}
		r.populateInboundStats(&inboundStats, statsGetter, remoteTrack)
		collector.Collect(inboundID, inboundStats)

		if remoteTrack.Kind() == RTPCodecTypeAudio {
			r.collectAudioPlayoutStats(collector, nowTime, remoteTrack)
		}
	}
}

// populateInboundStats copies the interceptor's counters for remoteTrack's
// SSRC into inboundStats. inboundStats is left untouched when statsGetter has
// no entry for that SSRC.
func (r *RTPReceiver) populateInboundStats(
	inboundStats *InboundRTPStreamStats,
	statsGetter stats.Getter,
	remoteTrack *TrackRemote,
) {
	// Named streamStats to avoid shadowing the imported stats package.
	streamStats := statsGetter.Get(uint32(remoteTrack.SSRC()))
	if streamStats == nil {
		return
	}

	// Wrap-around casting by design, with warnings if overflow/underflow is detected.
	pr := streamStats.InboundRTPStreamStats.PacketsReceived
	if pr > math.MaxUint32 {
		r.log.Warnf("Inbound PacketsReceived exceeds uint32 and will wrap: %d", pr)
	}
	inboundStats.PacketsReceived = uint32(pr) //nolint:gosec

	pl := streamStats.InboundRTPStreamStats.PacketsLost
	if pl > math.MaxInt32 || pl < math.MinInt32 {
		r.log.Warnf("Inbound PacketsLost exceeds int32 range and will wrap: %d", pl)
	}
	inboundStats.PacketsLost = int32(pl) //nolint:gosec

	inboundStats.Jitter = streamStats.InboundRTPStreamStats.Jitter
	inboundStats.BytesReceived = streamStats.InboundRTPStreamStats.BytesReceived
	inboundStats.HeaderBytesReceived = streamStats.InboundRTPStreamStats.HeaderBytesReceived

	// time.Time.UnixNano is undefined for the zero Time (year 1); presumably the
	// timestamp is zero until a packet has been seen, so leave the field unset
	// then. UnixMilli yields the same milliseconds-since-epoch for real times.
	if ts := streamStats.InboundRTPStreamStats.LastPacketReceivedTimestamp; !ts.IsZero() {
		inboundStats.LastPacketReceivedTimestamp = StatsTimestamp(ts.UnixMilli())
	}

	inboundStats.FIRCount = streamStats.InboundRTPStreamStats.FIRCount
	inboundStats.PLICount = streamStats.InboundRTPStreamStats.PLICount
	inboundStats.NACKCount = streamStats.InboundRTPStreamStats.NACKCount
}

// collectAudioPlayoutStats pulls the audio playout stats accumulated by
// remoteTrack up to nowTime and adds each one to collector under its own ID.
func (r *RTPReceiver) collectAudioPlayoutStats(
	collector *statsReportCollector,
	nowTime time.Time,
	remoteTrack *TrackRemote,
) {
	for _, entry := range remoteTrack.pullAudioPlayoutStats(nowTime) {
		collector.Collecting()
		collector.Collect(entry.ID, entry)
	}
}

// streamsForTrack returns a pointer to the r.tracks entry whose track is t, or
// nil if there is none. It does not lock r.mu itself.
// NOTE(review): setRTPReadDeadline holds r.mu around this call; readRTX does
// not — confirm that caller is race-free.
func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
	found := -1
	for i := range r.tracks {
		if r.tracks[i].track == t {
			found = i

			break
		}
	}

	if found < 0 {
		return nil
	}

	return &r.tracks[found]
}

// readRTP should only be called by a track, this only exists so we can keep state in one place.
//
// It blocks until Receive has completed (returning io.EOF if the receiver is
// stopped first) and then reads from reader's RTP interceptor chain. An error
// wrapping errRTPReceiverWithSSRCTrackStreamNotFound is returned when reader
// is unknown or its RTP stream has not been set up yet.
func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a interceptor.Attributes, err error) {
	select {
	case <-r.received:
	case <-r.closedChan:
		return 0, nil, io.EOF
	}

	// startReceive and receiveForRid assign rtpInterceptor under r.mu, so look
	// it up under the read lock; release the lock before the blocking Read so
	// Stop (which takes r.mu and closes the streams) is not held up.
	var rtpInterceptor interceptor.RTPReader
	r.mu.RLock()
	if t := r.streamsForTrack(reader); t != nil {
		rtpInterceptor = t.rtpInterceptor
	}
	r.mu.RUnlock()

	if rtpInterceptor == nil {
		return 0, nil, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC())
	}

	return rtpInterceptor.Read(b, a)
}

// receiveForRid is the sibling of Receive except for RIDs instead of SSRCs.
// It populates all the internal state for the given RID: the matching track's
// kind, codec, parameters, SSRC and peeked packets, and the track's stream and
// interceptor fields.
//
// The track's codec is set to params.Codecs[0]; if params carries no codecs
// the codec is left unchanged instead of panicking. An error wrapping
// errRTPReceiverForRIDTrackStreamNotFound is returned if no track has rid.
func (r *RTPReceiver) receiveForRid(
	rid string,
	params RTPParameters,
	streamInfo *interceptor.StreamInfo,
	rtpReadStream *srtp.ReadStreamSRTP,
	rtpInterceptor interceptor.RTPReader,
	rtcpReadStream *srtp.ReadStreamSRTCP,
	rtcpInterceptor interceptor.RTCPReader,
	peekedPackets []*peekedPacket,
) (*TrackRemote, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	for i := range r.tracks {
		streams := &r.tracks[i]
		track := streams.track
		if track == nil || track.RID() != rid {
			continue
		}

		track.mu.Lock()
		track.kind = r.kind
		if len(params.Codecs) > 0 {
			track.codec = params.Codecs[0]
		}
		track.params = params
		track.ssrc = SSRC(streamInfo.SSRC)
		track.peekedPackets = peekedPackets
		track.mu.Unlock()

		streams.streamInfo = streamInfo
		streams.rtpReadStream = rtpReadStream
		streams.rtpInterceptor = rtpInterceptor
		streams.rtcpReadStream = rtcpReadStream
		streams.rtcpInterceptor = rtcpInterceptor

		return track, nil
	}

	return nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
}

// receiveForRtx attaches an RTX repair stream to one of the receiver's tracks
// and starts a goroutine that rewrites each RTX packet into the format of the
// original RTP packet and queues it on that track's repairStreamChannel, from
// which readRTX hands it out.
//
// Track selection: with ssrc != 0 and exactly one track, that track is used;
// with ssrc != 0 and several tracks, the track whose RtxSSRC equals ssrc is
// preferred. Otherwise the first track whose RID equals rsid is used, and its
// RTX SSRC is recorded if not yet known. An error wrapping
// errRTPReceiverForRIDTrackStreamNotFound is returned when nothing matches.
//
// r.tracks is accessed without taking r.mu; startReceive calls this with r.mu
// already held. NOTE(review): confirm other callers provide the same guard.
//
// The goroutine exits when the repair interceptor's Read fails, or when the
// receiver is stopped while a packet is being queued.
//
//nolint:cyclop
func (r *RTPReceiver) receiveForRtx(
	ssrc SSRC,
	rsid string,
	streamInfo *interceptor.StreamInfo,
	rtpReadStream *srtp.ReadStreamSRTP,
	rtpInterceptor interceptor.RTPReader,
	rtcpReadStream *srtp.ReadStreamSRTCP,
	rtcpInterceptor interceptor.RTCPReader,
) error {
	var streams *trackStreams
	if ssrc != 0 {
		if len(r.tracks) == 1 {
			streams = &r.tracks[0]
		} else {
			// SSRC-signalled encodings carry no RID, so matching on rsid alone
			// would always pick the first such track.
			for i := range r.tracks {
				if t := r.tracks[i].track; t != nil && t.RtxSSRC() == ssrc {
					streams = &r.tracks[i]

					break
				}
			}
		}
	}
	if streams == nil {
		for i := range r.tracks {
			t := r.tracks[i].track
			if t == nil || t.RID() != rsid {
				continue
			}
			streams = &r.tracks[i]
			if t.RtxSSRC() == 0 {
				t.setRtxSSRC(SSRC(streamInfo.SSRC))
			}

			break
		}
	}

	if streams == nil {
		return fmt.Errorf("%w: ssrc(%d) rsid(%s)", errRTPReceiverForRIDTrackStreamNotFound, ssrc, rsid)
	}

	streams.repairStreamInfo = streamInfo
	streams.repairReadStream = rtpReadStream
	streams.repairInterceptor = rtpInterceptor
	streams.repairRtcpReadStream = rtcpReadStream
	streams.repairRtcpInterceptor = rtcpInterceptor
	streams.repairStreamChannel = make(chan rtxPacketWithAttributes, 50)

	go func() {
		// Size of the fixed RTP header (RFC 3550, section 5.1).
		const fixedHeaderLen = 12

		for {
			buf := r.rtxPool.Get().([]byte) // nolint:forcetypeassert
			n, attributes, err := streams.repairInterceptor.Read(buf, nil)
			if err != nil {
				r.rtxPool.Put(buf) // nolint:staticcheck

				return
			}

			// Bytes past n are stale data left in this pooled buffer by an earlier
			// packet, so every offset used below is validated against n.
			if n < fixedHeaderLen {
				r.rtxPool.Put(buf) // nolint:staticcheck

				continue
			}

			// RTX packets have a different payload format. Move the OSN in the payload to the RTP header and rewrite the
			// payload type and SSRC, so that we can return RTX packets to the caller 'transparently' i.e. in the same format
			// as non-RTX RTP packets
			hasExtension := buf[0]&0b10000 > 0
			hasPadding := buf[0]&0b100000 > 0
			// int arithmetic: a uint16 header length could silently wrap for large
			// extension lengths.
			headerLength := fixedHeaderLen + 4*int(buf[0]&0b1111)
			if hasExtension {
				if headerLength+4 > n {
					// Truncated extension header, drop.
					r.rtxPool.Put(buf) // nolint:staticcheck

					continue
				}
				headerLength += 4 * (1 + int(binary.BigEndian.Uint16(buf[headerLength+2:headerLength+4])))
			}
			paddingLength := 0
			if hasPadding {
				paddingLength = int(buf[n-1])
			}

			if n-headerLength-paddingLength < 2 {
				// BWE probe packet, ignore
				r.rtxPool.Put(buf) // nolint:staticcheck

				continue
			}

			if attributes == nil {
				attributes = make(interceptor.Attributes)
			}
			attributes.Set(AttributeRtxPayloadType, buf[1]&0x7F)
			attributes.Set(AttributeRtxSequenceNumber, binary.BigEndian.Uint16(buf[2:4]))
			attributes.Set(AttributeRtxSsrc, binary.BigEndian.Uint32(buf[8:12]))

			buf[1] = (buf[1] & 0x80) | uint8(streams.track.PayloadType())
			buf[2] = buf[headerLength]
			buf[3] = buf[headerLength+1]
			binary.BigEndian.PutUint32(buf[8:12], uint32(streams.track.SSRC()))
			copy(buf[headerLength:n-2], buf[headerLength+2:n])

			pkt := rtxPacketWithAttributes{pkt: buf[:n-2], attributes: attributes, pool: &r.rtxPool}
			select {
			case <-r.closedChan:
				r.rtxPool.Put(buf) // nolint:staticcheck

				return
			case streams.repairStreamChannel <- pkt:
			default:
				// Skip the RTX packet if the repair stream channel is full (the
				// application's read loop may be blocked) and recycle its buffer.
				r.rtxPool.Put(buf) // nolint:staticcheck
			}
		}
	}()

	return nil
}

// SetReadDeadline sets the max amount of time the RTCP stream will block before returning. 0 is forever.
// It applies to the first track's RTCP stream and returns an error wrapping
// errRTPReceiverWithSSRCTrackStreamNotFound if that stream does not exist yet
// (no tracks, or streams not set up).
func (r *RTPReceiver) SetReadDeadline(t time.Time) error {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if len(r.tracks) == 0 || r.tracks[0].rtcpReadStream == nil {
		var ssrc SSRC
		if len(r.tracks) > 0 && r.tracks[0].track != nil {
			ssrc = r.tracks[0].track.SSRC()
		}

		return fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, ssrc)
	}

	return r.tracks[0].rtcpReadStream.SetReadDeadline(t)
}

// SetReadDeadlineSimulcast sets the max amount of time the RTCP stream for a given rid will block before returning.
// 0 is forever. The first track with a matching RID is used; an error wrapping
// errRTPReceiverForRIDTrackStreamNotFound is returned if there is no such track
// or its RTCP stream has not been set up yet.
func (r *RTPReceiver) SetReadDeadlineSimulcast(deadline time.Time, rid string) error {
	r.mu.RLock()
	defer r.mu.RUnlock()

	for i := range r.tracks {
		streams := &r.tracks[i]
		if streams.track == nil || streams.track.rid != rid {
			continue
		}
		// rtcpReadStream is only assigned by startReceive or receiveForRid.
		if streams.rtcpReadStream == nil {
			break
		}

		return streams.rtcpReadStream.SetReadDeadline(deadline)
	}

	return fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
}

// setRTPReadDeadline sets the max amount of time the RTP stream will block before returning. 0 is forever.
// This should be fired by calling SetReadDeadline on the TrackRemote.
// It returns an error wrapping errRTPReceiverWithSSRCTrackStreamNotFound when
// reader is unknown or its RTP stream has not been set up yet.
func (r *RTPReceiver) setRTPReadDeadline(deadline time.Time, reader *TrackRemote) error {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if t := r.streamsForTrack(reader); t != nil && t.rtpReadStream != nil {
		return t.rtpReadStream.SetReadDeadline(deadline)
	}

	return fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC())
}

// readRTX returns an RTX packet if one is available on the RTX track, otherwise returns nil.
// It never blocks. It also returns nil when reader has no RTX stream, when
// Receive has not completed, or when reader is not one of this receiver's
// tracks.
func (r *RTPReceiver) readRTX(reader *TrackRemote) *rtxPacketWithAttributes {
	if !reader.HasRTX() || !r.haveReceived() {
		return nil
	}

	streams := r.streamsForTrack(reader)
	if streams == nil {
		return nil
	}

	select {
	case pkt := <-streams.repairStreamChannel:
		return &pkt
	default:
		return nil
	}
}
